/*
 * Copyright 2025-2025 the original author or authors.
 */

package io.modelcontextprotocol;

import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpSyncServer;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.server.TestUtil;
import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.ProtocolVersions;
import io.modelcontextprotocol.utils.McpTestRequestRecordingExchangeFilterFunction;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerResponse;

import static org.assertj.core.api.Assertions.assertThat;

class WebFluxStreamableHttpVersionNegotiationIntegrationTests {

	private static final int PORT = TestUtil.findAvailablePort();

	private DisposableServer httpServer;

	private final McpTestRequestRecordingExchangeFilterFunction recordingFilterFunction = new McpTestRequestRecordingExchangeFilterFunction();

	private final McpSchema.Tool toolSpec = McpSchema.Tool.builder()
		.name("test-tool")
		.description("return the protocol version used")
		.build();

	private final BiFunction<McpSyncServerExchange, McpSchema.CallToolRequest, McpSchema.CallToolResult> toolHandler = (
			exchange, request) -> new McpSchema.CallToolResult(
					exchange.transportContext().get("protocol-version").toString(), null);

	private final WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider
		.builder()
		.contextExtractor(req -> McpTransportContext
			.create(Map.of("protocol-version", req.headers().firstHeader("MCP-protocol-version"))))
		.build();

	private final McpSyncServer mcpServer = McpServer.sync(mcpStreamableServerTransportProvider)
		.capabilities(McpSchema.ServerCapabilities.builder().tools(false).build())
		.tools(new McpServerFeatures.SyncToolSpecification(toolSpec, null, toolHandler))
		.build();

	@BeforeEach
	void setUp() {
		RouterFunction<ServerResponse> filteredRouter = mcpStreamableServerTransportProvider.getRouterFunction()
			.filter(recordingFilterFunction);

		HttpHandler httpHandler = RouterFunctions.toHttpHandler(filteredRouter);

		ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler);

		this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();
	}

	@AfterEach
	public void after() {
		if (httpServer != null) {
			httpServer.disposeNow();
		}
		if (mcpServer != null) {
			mcpServer.close();
		}
	}

	@Test
	void usesLatestVersion() {
		var client = McpClient
			.sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
				.build())
			.requestTimeout(Duration.ofHours(10))
			.build();

		client.initialize();

		McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));

		var calls = recordingFilterFunction.getCalls();
		assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\""))
			// GET /mcp ; POST notification/initialized ; POST tools/call
			.hasSize(3)
			.map(McpTestRequestRecordingExchangeFilterFunction.Call::headers)
			.allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version",
					ProtocolVersions.MCP_2025_06_18));

		assertThat(response).isNotNull();
		assertThat(response.content()).hasSize(1)
			.first()
			.extracting(McpSchema.TextContent.class::cast)
			.extracting(McpSchema.TextContent::text)
			.isEqualTo(ProtocolVersions.MCP_2025_06_18);
		mcpServer.close();
	}

	@Test
	void usesServerSupportedVersion() {
		var transport = WebClientStreamableHttpTransport
			.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
			.supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_06_18, "2263-03-18"))
			.build();
		var client = McpClient.sync(transport).requestTimeout(Duration.ofHours(10)).build();

		client.initialize();

		McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));

		var calls = recordingFilterFunction.getCalls();
		// Initialize tells the server the Client's latest supported version
		// FIXME: Set the correct protocol version on GET /mcp
		assertThat(calls)
			.filteredOn(c -> !c.body().contains("\"method\":\"initialize\"") && c.method().equals(HttpMethod.POST))
			// POST notification/initialized ; POST tools/call
			.hasSize(2)
			.map(McpTestRequestRecordingExchangeFilterFunction.Call::headers)
			.allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version",
					ProtocolVersions.MCP_2025_06_18));

		assertThat(response).isNotNull();
		assertThat(response.content()).hasSize(1)
			.first()
			.extracting(McpSchema.TextContent.class::cast)
			.extracting(McpSchema.TextContent::text)
			.isEqualTo(ProtocolVersions.MCP_2025_06_18);
		mcpServer.close();
	}

}
